import numpy as np
import random 
import numba
from numba import jit, njit

from . import weighted_path
from ..tinyhouse import InputOutput

try:
    # `profile` is injected as a builtin when running under line_profiler/kernprof.
    profile
except NameError:
    def profile(x):
        """No-op stand-in for line_profiler's ``@profile`` decorator; returns ``x`` unchanged."""
        return x
# Silence numba's (pending) deprecation warnings. The warning classes live in
# numba.core.errors on numba >= 0.49 and in numba.errors on older releases.
try:
    try:
        from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
    except ImportError:
        from numba.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
except ImportError:
    # Neither location exists; leave warning filters untouched (best effort).
    pass
else:
    import warnings

    warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
    warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)

def get_genetic_map(chr_map):
    """Return the ``(start, stop)`` locus span of each chromosome.

    ``chr_map`` gives, for every locus, its chromosome number. Chromosomes
    are numbered from 1. Chromosome numbers below 1 are ignored.
    For chromosome ``c`` the span runs from the first locus labelled ``c`` to
    one past the last locus labelled ``c``. Chromosomes are assumed to occupy
    contiguous runs of loci; interleaved labels would produce overlapping
    spans.

    Chromosome numbers in ``1..max(chr_map)`` that never occur are skipped.
    An empty ``chr_map`` yields an empty list.

    Returns a list of ``(start, stop)`` tuples of Python ints, usable as
    ``slice(start, stop)``.
    """
    chr_map = np.asarray(chr_map)
    if chr_map.size == 0:
        return []

    genetic_map = []
    for chrom in range(1, int(np.max(chr_map)) + 1):
        mask = chr_map == chrom
        if not mask.any():
            # argmax of an all-False mask is 0, which would map a missing
            # chromosome onto the entire genome; skip it instead.
            continue
        start = int(np.argmax(mask))
        stop = len(mask) - int(np.argmax(mask[::-1]))
        genetic_map.append((start, stop))

    return genetic_map

@profile
def imputeIndividualFromPhenotypes_haploid(ind, sire, dam, qtls, chrMap):
    """Impute ``ind.dosages`` from ``ind.phenotype`` in haploid mode.

    Only the first haplotype of each parent is used. It is doubled because
    the offspring inherits two copies. The result of ``gibbs_sampler`` is
    written to ``ind.dosages``; nothing is returned.
    """
    chromosome_spans = get_genetic_map(chrMap)

    # Row 0: sire's first haplotype x2, row 1: dam's first haplotype x2.
    doubled_parental = np.array(
        [sire.haplotypes[0] * 2, dam.haplotypes[0] * 2],
        dtype=np.int8,
    )

    ind.dosages = gibbs_sampler(ind.phenotype, chromosome_spans, qtls, doubled_parental)


# def imputeIndividualFromPhenotypes(ind, sire, dam, qtls, chrMap):
#     # Yeah, the only thing we need from ind is their phenotype.
#     print("Running in diploid mode")

#     qtl_effects = qtls
#     genetic_map = get_genetic_map(chrMap)

#     paternal_haplotypes = np.array(sire.haplotypes, dtype = np.int8)
#     maternal_haplotypes = np.array(dam.haplotypes, dtype = np.int8)
#     target_phenotype = ind.phenotype

#     ind.dosages = gibbs_sampler(target_phenotype, genetic_map, qtl_effects, paternal_haplotypes, maternal_haplotypes)


@profile
def gibbs_sampler(target_phenotype, genetic_map, qtl_effects, paternal_haplotypes, n_replicates=10):
    """Estimate per-locus dosages by averaging independent Gibbs-sampler runs.

    Parameters
    ----------
    target_phenotype : per-trait phenotype the sampled genotype should explain.
    genetic_map : list of ``(start, stop)`` chromosome spans.
    qtl_effects : array sliced as ``qtl_effects[:, start:stop]`` (traits x loci).
    paternal_haplotypes : array sliced as ``[:, start:stop]``; its last
        dimension is the number of loci.
    n_replicates : number of independent sampler runs to average (default 10).

    Returns
    -------
    float32 array of length ``paternal_haplotypes.shape[-1]``. It holds the mean
    dosage over the replicates. If ``n_replicates`` is 0, it is all zeros.
    """
    sampling_information = Global_Information(genetic_map, qtl_effects, paternal_haplotypes)

    nLoci = paternal_haplotypes.shape[-1]
    dosages = np.full(nLoci, 0, dtype = np.float32)

    # Fold each replicate in as soon as it is produced rather than keeping
    # every state object alive until the end.
    for _ in range(n_replicates):
        rep = wrapper_numba_gibbs(target_phenotype, sampling_information)
        add_rep_to_dosages(rep, sampling_information, dosages, n_replicates)

    return dosages

@profile
def run_gibbs_sampler(target_phenotype, sampling_information, n_iterations=100):
    """Pure-Python Gibbs sampler (reference path for ``wrapper_numba_gibbs``).

    Each iteration resamples every chromosome in order and then records the
    current path of every chromosome into its running average.

    Parameters
    ----------
    target_phenotype : per-trait phenotype to explain.
    sampling_information : a ``Global_Information`` instance.
    n_iterations : number of full sweeps over the chromosomes (default 100).

    Returns
    -------
    The ``Global_State`` holding the accumulated per-chromosome paths.
    """
    state_information = Global_State(target_phenotype, sampling_information)

    # The chromosome count does not change between sweeps.
    nChrom = len(sampling_information.chromosomes)
    for _ in range(n_iterations):
        for chrom in range(nChrom):
            state_information.sample(chrom)
        state_information.add_current_state()

    return state_information

def add_rep_to_dosages(rep, sampling_information, dosages, n_replicates):
    """Add one replicate's dosages, scaled by ``1/n_replicates``, into ``dosages`` in place."""
    replicate_dosages = sampling_information.get_genotypes_from_state(rep)
    dosages += replicate_dosages / n_replicates



def wrapper_numba_gibbs(target_phenotype, sampling_information):
    """Run one numba Gibbs replicate and package the result as a ``Global_State``.

    Per-chromosome arrays are unpacked into plain tuples because the jitted
    sampler cannot take the Python information objects directly. The summed
    paths and path counts it returns are copied into the matching
    ``Chrom_State`` objects.
    """
    state_information = Global_State(target_phenotype, sampling_information)
    chrom_infos = sampling_information.chromosomes

    path_scores_list = tuple(info.path_scores for info in chrom_infos)
    residual_means_list = tuple(info.residual_means for info in chrom_infos)
    residual_variances_list = tuple(info.residual_variances for info in chrom_infos)

    # Running sums of sampled paths, filled in place by the jitted sampler.
    path_sums = [np.zeros(info.nLoci, dtype=np.float32) for info in chrom_infos]

    path_sums, path_counts = numba_run_gibbs_sampler(
        target_phenotype,
        path_sums,
        path_scores_list,
        residual_means_list,
        residual_variances_list,
    )

    for index, chrom_state in enumerate(state_information.chromosome_states):
        chrom_state.average_path = path_sums[index]
        chrom_state.average_path_count = path_counts[index]

    return state_information

@jit(nopython=True, nogil = True)
def numba_run_gibbs_sampler(target_phenotype, average_paths, path_scores_list, residual_means_list, residual_variances_list):
    """Run the chromosome-by-chromosome Gibbs sampler in numba nopython mode.

    Each iteration visits every chromosome in turn. It removes that
    chromosome's current phenotype contribution, draws a new path with
    ``weighted_path.sample_path`` against the residual target, and adds the
    new contribution back.

    Parameters
    ----------
    target_phenotype : 1-D array, one entry per trait.
    average_paths : list of 1-D arrays, one per chromosome. Each sampled path
        is added into the matching array in place, giving a running sum
        rather than a mean.
    path_scores_list : tuple of ``(nLoci, nTraits, 2)`` arrays, one per chromosome.
    residual_means_list, residual_variances_list : tuples of per-chromosome
        arrays, forwarded unchanged to ``weighted_path.sample_path``.

    Returns
    -------
    ``(average_paths, path_counts)``: the accumulated path sums and, for each
    chromosome, how many paths were added (float32).

    Raises
    ------
    ValueError if a chromosome's path scores do not have one column per trait.
    """
    nChrom = len(average_paths)
    nTraits = len(target_phenotype)

    path_counts = np.full(nChrom, 0, dtype = np.float32)

    # Per-chromosome and total phenotype contributions.
    chrom_phenotype = np.full((nChrom, nTraits), 0, dtype = np.float32)
    current_phenotype = np.full(nTraits, 0, dtype = np.float32)

    # Start every chromosome at the expected contribution of an even mix of
    # its two parental haplotypes. The sum runs over loci separately for each
    # trait; summing across traits would mix unrelated traits together.
    for i in range(nChrom):
        scores = path_scores_list[i]
        if scores.shape[1] != nTraits:
            # numba does not bounds-check indexing, so validate explicitly.
            raise ValueError("path_scores trait dimension does not match target_phenotype")
        for trait in range(nTraits):
            total = 0.0  # float64 accumulator to limit float32 round-off
            for locus in range(scores.shape[0]):
                total += scores[locus, trait, 0] + scores[locus, trait, 1]
            chrom_phenotype[i, trait] = 0.5 * total

        current_phenotype += chrom_phenotype[i, :]

    n_iterations = 100
    for iteration in range(n_iterations):
        for chrom in range(nChrom):

            # Remove this chromosome's contribution so the residual is what
            # the other chromosomes leave unexplained.
            current_phenotype -= chrom_phenotype[chrom, :]

            # Set up variables and run path sampling.
            path_scores = path_scores_list[chrom]
            residual_means = residual_means_list[chrom]
            residual_variances = residual_variances_list[chrom]
            rec_rate = 1.0/len(average_paths[chrom])

            path, phenotype = weighted_path.sample_path(target_phenotype - current_phenotype, path_scores, residual_means, residual_variances, rec_rate)

            # Update chromosome variables.
            average_paths[chrom] += path
            chrom_phenotype[chrom, :] = phenotype
            path_counts[chrom] += 1

            # Add the new contribution back in.
            current_phenotype += chrom_phenotype[chrom, :]

    return average_paths, path_counts



class Global_Information(object):
    # Genome-wide container: one Chromosome_Information per (start, stop)
    # span of the genetic map, plus the total locus count.

    def __init__(self, genetic_map, qtl_effects, paternal_haplotypes):
        self.genetic_map = genetic_map
        self.chromosomes = [
            Chromosome_Information(qtl_effects[:, start:stop], paternal_haplotypes[:, start:stop])
            for start, stop in genetic_map
        ]
        self.nLoci = paternal_haplotypes.shape[1]

    def get_genotypes_from_state(self, state):
        """Convert each chromosome's average path in ``state`` into dosages.

        Returns a float32 array of length ``self.nLoci``. Each chromosome's
        dosages are added into its ``[start:stop]`` slice.
        """
        genome_dosages = np.zeros(self.nLoci, dtype=np.float32)

        for position, (first, last) in enumerate(self.genetic_map):
            chrom_info = self.chromosomes[position]
            mean_path = state.chromosome_states[position].get_average_path()
            genome_dosages[first:last] += chrom_info.convert_path_to_dosage(mean_path)

        return genome_dosages


class Chromosome_Information(object):
    # Static, per-chromosome inputs to the sampler: the two parental
    # haplotypes, their per-locus trait scores, and the residual mean/variance
    # tables produced by weighted_path.set_mean_variance_all.
    def __init__(self, qtl_effects, haplotypes):

        self.nTraits, self.nLoci = qtl_effects.shape
        self.parental_haplotypes = haplotypes
        self.path_scores = self.get_segregation_effects(haplotypes, qtl_effects)

        recombination_rate = 1.0 / self.nLoci
        self.residual_means, self.residual_variances = weighted_path.set_mean_variance_all(self.path_scores, recombination_rate)

    def get_segregation_effects(self, parental_haplotypes, qtl_effects):
        """Return float32 scores of shape ``(nLoci, nTraits, 2)``.

        ``scores[l, t, h]`` is ``parental_haplotypes[h, l] * qtl_effects[t, l]``.
        Only haplotypes 0 and 1 are used. Haplotype values are multiplied
        rather than used as indices, so non-integer values and values above 1
        are fine.
        """
        n_traits, n_loci = qtl_effects.shape
        effects_by_locus = qtl_effects.T  # (nLoci, nTraits)

        scores = np.zeros((n_loci, n_traits, 2), dtype=np.float32)
        for which in (0, 1):
            scores[..., which] = parental_haplotypes[which][:, np.newaxis] * effects_by_locus

        return scores

    def convert_path_to_dosage(self, path):
        """Blend the two parental haplotypes by ``path``.

        Where path is 0 the result equals haplotype 0; where it is 1 the result
        equals haplotype 1.
        """
        hap_zero = self.parental_haplotypes[0, :]
        hap_one = self.parental_haplotypes[1, :]
        return hap_zero * (1 - path) + hap_one * path

class Global_State(object):
    # Mutable sampler state for the pure-Python path: one Chrom_State per
    # chromosome plus the running sum of their phenotype contributions.

    def __init__(self, target_phenotype, sampling_information):
        self.target_phenotype = target_phenotype
        self.current_phenotype = 0
        self.chromosome_states = [Chrom_State(info) for info in sampling_information.chromosomes]

    @profile
    def sample(self, chrom):
        """Resample chromosome ``chrom`` against what the other chromosomes leave unexplained."""
        chrom_state = self.chromosome_states[chrom]
        previous_contribution = chrom_state.current_phenotype

        # Target minus everyone else's contribution (this chromosome's old
        # contribution is added back in).
        residual = self.target_phenotype - self.current_phenotype + previous_contribution
        fresh_contribution = chrom_state.sample(residual)

        self.current_phenotype = self.current_phenotype - previous_contribution + fresh_contribution

    def add_current_state(self):
        """Record every chromosome's current path into its running average."""
        for chrom_state in self.chromosome_states:
            chrom_state.add_current_state()



class Chrom_State(object):
    # Samples/resamples one chromosome and keeps a running sum of its paths.

    def __init__(self, chromosome_information):
        self.chromosome_information = chromosome_information
        self.nLoci = chromosome_information.nLoci

        # Latest draw; no path exists until sample() has run.
        self.current_path = None
        self.current_phenotype = 0

        # Running sum of recorded paths and how many were recorded.
        self.average_path = np.zeros(self.nLoci, dtype=np.float32)
        self.average_path_count = 0

    @profile
    def sample(self, target_phenotype):
        """Draw a new path toward ``target_phenotype``; return its phenotype contribution."""
        info = self.chromosome_information
        self.current_path, self.current_phenotype = weighted_path.sample_path(
            target_phenotype,
            info.path_scores,
            info.residual_means,
            info.residual_variances,
            1.0 / self.nLoci,
        )
        return self.current_phenotype

    def add_current_state(self):
        """Add the latest sampled path to the running sum."""
        self.average_path += self.current_path
        self.average_path_count += 1

    def get_average_path(self):
        """Return the mean of all recorded paths."""
        return self.average_path / self.average_path_count












